import torch

def OrthProj(X: torch.Tensor):
    """Return the orthonormal factor of a tall 2-D matrix and its nuclear norm.

    With the reduced SVD ``X = P @ diag(S) @ Q.T``, the returned matrix is
    ``R = P @ Q.T`` (the singular values are dropped), and the nuclear norm
    is the sum of the singular values ``S``.

    Args:
        X: 2-D tensor of shape ``(n, d)`` with ``n >= d``.

    Returns:
        A tuple ``(R, nuclear_norm_X)`` where ``R`` has the same shape as
        ``X`` and ``nuclear_norm_X`` is a 0-dim tensor equal to ``sum(S)``.

    Raises:
        ValueError: if ``X`` has fewer rows than columns (``n < d``), or if
            ``X`` is not 2-D (the shape unpacking fails).
    """
    n, d = X.size()

    # Wide inputs are rejected up front. The previous implementation carried
    # a transposed-SVD path after the raise that could never execute.
    if n < d:
        raise ValueError(
            f'OrthProj requires n >= d, got X of shape ({n}, {d})'
        )

    # torch.svd returns the reduced factors by default: P is (n, d),
    # S is (d,), Q is (d, d).
    # NOTE(review): torch.svd is deprecated upstream in favour of
    # torch.linalg.svd; kept here to avoid raising the minimum torch version.
    P, S, Q = torch.svd(X)
    R = torch.mm(P, Q.t())
    nuclear_norm_X = torch.sum(S)

    return R, nuclear_norm_X

